import sys
import csv
import os
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
import numpy as np
import copy
import random
import pickle
def load_object(file_path):
    """Unpickle and return the single object stored in ``file_path``.

    NOTE(review): ``pickle.load`` can execute arbitrary code, so only call this
    on files you produced yourself.
    """
    with open(file_path, 'rb') as handle:
        return pickle.load(handle)

print(" Trining Embedding model for simulation agents ")
experiment_setting = input("Input experiment setting. (0 : standard setting, 1 : ablation of triplet loss, 2: ablation of recontruction loss) :")
subject_class = input("Input the target group. (A or B)")
subject_number = input("Input the simulated csv file name : (.csv)")
query= input("How many triplets included in the csv file? (1500, 3000, 4500, 6500, 7500) :")
assienment= input("Please select the target assignment for analysis. (1 or 2 or 3)")
ini_random_seed= input("Please input the random seed for model training.")

loaded_img_data = load_object('../train_img_128.pkl') # modify this path on your setting (You should <0. make image numpy file> first


if '.csv' not in subject_number:
    subject_number=subject_number+'.csv'

if subject_class=="a":
    subject_class ="A"
elif subject_class=="b":
    subject_class="B"


# Training hyper-parameters.
batch_size = 32
learning_rate = 1e-4
num_epochs = 5000

# Both RNGs are seeded from the raw input string: random.seed hashes the
# string itself, while torch.manual_seed converts it with int().
random.seed(ini_random_seed)
torch.manual_seed(ini_random_seed)

# Working floating-point dtype (float32) and the best available accelerator.
dtype = torch.float
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")


if subject_class=="A":
    training_data_num_=list(range(11, 511))
    with open('./random_triplet_assignment/problemA'+str(assienment)+'.pkl', 'rb') as f:
        ele_test=pickle.load(f)
        ele_val=pickle.load(f)
else:
    training_data_num_=list(range(11+510, 511+510))
    with open('./random_triplet_assignment/problemB'+str(assienment)+'.pkl', 'rb') as f:
        ele_test=pickle.load(f)
        ele_val=pickle.load(f)



random.shuffle(training_data_num_)

training_data_num=[]
val_data_num=[]
test_data_num=[]
for el in training_data_num_:
    if el in ele_val:
        val_data_num.append(el)
    elif el in ele_test:
        test_data_num.append(el)
    else:
        training_data_num.append(el)



def reading_csv(subject_class, subject_number, num_triplets=None):
    """Load the simulated triplet records for one problem group from a CSV file.

    Row 0 is treated as a header and skipped; blank lines are ignored. A data
    row is kept when its problem id (column 1) is one of the first
    ``num_triplets // 3`` ids of the group: 11, 12, ... for group "A" and
    521, 522, ... for group "B".

    Args:
        subject_class: "A" or "B". Any other value yields an empty list.
        subject_number: path of the CSV file to read.
        num_triplets: total number of triplets in the file (int or numeric
            string). Defaults to the module-level ``query`` answer.

    Returns:
        A list of records ``[[id], [c5, c6, c7], [c11, c12, c13], ["", "", ""]]``
        holding the raw string cells of each kept row (presumably the three
        image keys and their pairwise scores — confirm against the simulator),
        ordered by ascending problem id; rows sharing an id keep file order.

    Raises:
        ValueError: if column 1 of a non-blank data row is not an integer.
    """
    if num_triplets is None:
        num_triplets = query
    num_of_task = int(num_triplets) // 3
    first_id = {'A': 11, 'B': 510 + 11}.get(subject_class)

    # newline='' is what the csv module requires for correct quoting handling.
    with open(subject_number, 'r', newline='') as f:
        rows = list(csv.reader(f))

    if first_id is None or num_of_task <= 0:
        return []

    selected = []
    for row in rows[1:]:  # row 0 is the header
        if not row:
            continue  # csv.reader yields [] for blank lines
        task_offset = int(row[1]) - first_id
        if 0 <= task_offset < num_of_task:
            selected.append((task_offset, row))

    # list.sort is stable: groups rows by problem id, file order within an id.
    selected.sort(key=lambda item: item[0])

    return [
        [[row[1]], [row[5], row[6], row[7]], [row[11], row[12], row[13]], ["", "", ""]]
        for _, row in selected
    ]


# Parse the simulated CSV once; the group-specific id range is applied inside reading_csv.
S_problem_data = reading_csv(subject_class=subject_class, subject_number=subject_number)


class Autoencoder(nn.Module):
    """Convolutional autoencoder for 1-channel 128x128 images with a 64-d code.

    Four stride-2 convolutions shrink 128x128 to 128 channels of 8x8, which
    ``fc1`` projects to the 64-d embedding. ``fc2`` and four stride-2
    transposed convolutions map it back to a 1x128x128 image in (0, 1).
    ``forward`` returns ``(reconstruction, embedding)``.
    """

    def __init__(self):
        super().__init__()
        # Encoder (layer creation order is kept: it fixes the init RNG draws).
        self.conv1 = nn.Conv2d(1, 16, 3, 2, 1)
        self.relu1 = nn.ReLU()
        self.conv2 = nn.Conv2d(16, 32, 3, 2, 1)
        self.relu2 = nn.ReLU()
        self.conv3 = nn.Conv2d(32, 64, 3, 2, 1)
        self.relu3 = nn.ReLU()
        self.conv4 = nn.Conv2d(64, 128, 3, 2, 1)
        self.relu4 = nn.ReLU()

        # Bottleneck: 8192 -> 64 -> 8192.
        self.fc1 = nn.Linear(128 * 8 * 8, 64)
        self.fc2 = nn.Linear(64, 128 * 8 * 8)

        # Decoder.
        self.urelu1 = nn.ReLU()
        self.upconv1 = nn.ConvTranspose2d(128, 64, 4, 2, 1)
        self.urelu2 = nn.ReLU()
        self.upconv2 = nn.ConvTranspose2d(64, 32, 4, 2, 1)
        self.urelu3 = nn.ReLU()
        self.upconv3 = nn.ConvTranspose2d(32, 16, 4, 2, 1)
        self.urelu4 = nn.ReLU()
        self.upconv4 = nn.ConvTranspose2d(16, 1, 4, 2, 1)
        self.sig = nn.Sigmoid()

    def forward(self, x):
        encoder = (
            (self.conv1, self.relu1),
            (self.conv2, self.relu2),
            (self.conv3, self.relu3),
            (self.conv4, self.relu4),
        )
        for conv, act in encoder:
            x = act(conv(x))
        emb = self.fc1(torch.flatten(x, 1))

        x = self.urelu1(self.fc2(emb)).view(-1, 128, 8, 8)
        decoder = (
            (self.upconv1, self.urelu2),
            (self.upconv2, self.urelu3),
            (self.upconv3, self.urelu4),
        )
        for upconv, act in decoder:
            x = act(upconv(x))
        return self.sig(self.upconv4(x)), emb


class CustomDataset(Dataset):
    """Expose a NumPy array as a float32 tensor dataset, indexed along axis 0.

    ``transform``, when truthy, is applied to each sample on access.
    """

    def __init__(self, numpy_data, transform=None):
        self.data = torch.from_numpy(numpy_data).float()
        self.transform = transform

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        return self.transform(item) if self.transform else item




class CustomDataset_new(Dataset):
    """Expand triplet records into (anchor, positive, negative) image samples.

    Each record ``r`` contributes three samples, one per image in ``r[1]``
    used as anchor. For each anchor, the two remaining images are compared
    through the numbers in ``r[2]`` (parsed with ``float``). The one with the
    smaller value becomes the positive; on a tie it is the first candidate
    listed in ``_ANCHOR_LAYOUT``. The indexing is consistent with ``r[2]``
    holding scores for image pairs (0,1), (1,2), (0,2) — presumably distances,
    since the smaller one picks the positive; confirm against the simulator.
    """

    # (anchor, first candidate, second candidate,
    #  value index for first candidate, value index for second candidate)
    _ANCHOR_LAYOUT = (
        (0, 1, 2, 0, 2),
        (1, 0, 2, 0, 1),
        (2, 0, 1, 2, 1),
    )

    def __init__(self, conc_data, img_dic, transform=None):
        self.img_dic = img_dic
        self.conc = conc_data

        anchors, positives, negatives = [], [], []
        for record in conc_data:
            keys, values = record[1], record[2]
            for anchor, first, second, first_val, second_val in self._ANCHOR_LAYOUT:
                anchors.append(img_dic[keys[anchor]])
                if float(values[first_val]) <= float(values[second_val]):
                    pos, neg = first, second
                else:
                    pos, neg = second, first
                positives.append(img_dic[keys[pos]])
                negatives.append(img_dic[keys[neg]])

        self.transform = transform
        self.main_data1 = torch.from_numpy(np.array(anchors))
        self.main_data2 = torch.from_numpy(np.array(positives))
        self.main_data3 = torch.from_numpy(np.array(negatives))

    def __len__(self):
        return len(self.main_data1)

    def __getitem__(self, idx):
        triplet = (self.main_data1[idx], self.main_data2[idx], self.main_data3[idx])
        if self.transform:
            triplet = tuple(self.transform(part) for part in triplet)
        return triplet

class CustomDataset_new2(Dataset):
    """Training variant of the triplet expansion that also returns pair weights.

    Each record ``r`` yields three (anchor, positive, negative) samples, one per
    image in ``r[1]`` used as anchor. The positive is the candidate whose value
    in ``r[2]`` (parsed with ``float``) is smaller; on a tie it is the first
    candidate in ``_ANCHOR_LAYOUT``. Alongside the images it stores ``sim1``
    (the anchor–positive value) and ``sim2`` (the anchor–negative value) as
    float64 tensors. The indexing is consistent with ``r[2]`` holding scores for
    image pairs (0,1), (1,2), (0,2) — presumably distances; confirm.
    """

    # (anchor, first candidate, second candidate,
    #  value index for first candidate, value index for second candidate)
    _ANCHOR_LAYOUT = (
        (0, 1, 2, 0, 2),
        (1, 0, 2, 0, 1),
        (2, 0, 1, 2, 1),
    )

    def __init__(self, conc_data, img_dic, transform=None):
        self.img_dic = img_dic
        self.conc = conc_data

        anchors, positives, negatives = [], [], []
        sims_pos, sims_neg = [], []
        for record in conc_data:
            keys, values = record[1], record[2]
            for anchor, first, second, first_val, second_val in self._ANCHOR_LAYOUT:
                anchors.append(img_dic[keys[anchor]])
                if float(values[first_val]) <= float(values[second_val]):
                    pos, neg = first, second
                    near, far = first_val, second_val
                else:
                    pos, neg = second, first
                    near, far = second_val, first_val
                positives.append(img_dic[keys[pos]])
                negatives.append(img_dic[keys[neg]])
                sims_pos.append(float(values[near]))
                sims_neg.append(float(values[far]))

        self.transform = transform
        self.main_data1 = torch.from_numpy(np.array(anchors))
        self.main_data2 = torch.from_numpy(np.array(positives))
        self.main_data3 = torch.from_numpy(np.array(negatives))
        self.sim1 = torch.from_numpy(np.array(sims_pos))
        self.sim2 = torch.from_numpy(np.array(sims_neg))

    def __len__(self):
        return len(self.main_data1)

    def __getitem__(self, idx):
        images = (self.main_data1[idx], self.main_data2[idx], self.main_data3[idx])
        if self.transform:
            images = tuple(self.transform(part) for part in images)
        return (*images, self.sim1[idx], self.sim2[idx])


class ToTensor:
    """Callable transform that converts an array-like sample to a float32 tensor."""

    _TARGET_DTYPE = torch.float32

    def __call__(self, sample):
        # as_tensor skips the copy when the input already has the target dtype.
        return torch.as_tensor(sample, dtype=self._TARGET_DTYPE)

# Build the model first: its weight initialisation is the first draw from the
# torch RNG after seeding.
model = Autoencoder().to(device)
loaded_meta_data = S_problem_data

# Partition the parsed CSV records by the problem id stored in record[0][0].
_train_ids = set(training_data_num)
_val_ids = set(val_data_num)
_test_ids = set(test_data_num)
conc_training_data = [rec for rec in loaded_meta_data if int(rec[0][0]) in _train_ids]
conc_val_data = [rec for rec in loaded_meta_data if int(rec[0][0]) in _val_ids]
conc_test_data = [rec for rec in loaded_meta_data if int(rec[0][0]) in _test_ids]

_to_tensor = ToTensor()

# Training uses the weighted variant; evaluation only needs the image triplets.
custom_dataset_training2 = CustomDataset_new2(conc_training_data, loaded_img_data, transform=_to_tensor)
train_loader2 = DataLoader(custom_dataset_training2, batch_size=batch_size, shuffle=True)

# shuffle=True does not change the evaluation accuracy, but a shuffling
# DataLoader draws from the global torch RNG each epoch, so it is kept for
# run-to-run reproducibility with earlier results.
custom_dataset_val = CustomDataset_new(conc_val_data, loaded_img_data, transform=_to_tensor)
val_loader = DataLoader(custom_dataset_val, batch_size=batch_size, shuffle=True)

custom_dataset_test = CustomDataset_new(conc_test_data, loaded_img_data, transform=_to_tensor)
test_loader = DataLoader(custom_dataset_test, batch_size=batch_size, shuffle=True)


def criterion4(output_tensor, input_img, img1_emb, img2_emb, img3_emb, sim1, sim2, experiment_setting):
    """Combine the reconstruction loss with weighted embedding-distance terms.

    Args:
        output_tensor: reconstruction produced by the model for ``input_img``.
        input_img: the anchor images that were reconstructed.
        img1_emb, img2_emb, img3_emb: anchor, positive and negative embeddings;
            distances are averaged over dim 1. Only ``img1_emb`` receives
            gradients (the other two are detached).
        sim1, sim2: per-sample weights. In this script they come from
            CustomDataset_new2 (sim1 = anchor–positive value, sim2 =
            anchor–negative value). Note the cross-weighting: the
            anchor–positive distance is scaled by ``sim2`` and the
            anchor–negative distance by ``sim1``.
        experiment_setting: 0 (standard), 1 (reconstruction only) or
            2 (embedding terms only), as an int or numeric string.

    Returns:
        A scalar loss tensor.

    Raises:
        ValueError: if ``experiment_setting`` is not 0, 1 or 2 (it previously
            fell through to setting 2 silently).
    """
    setting = int(experiment_setting)
    if setting not in (0, 1, 2):
        raise ValueError("experiment_setting must be 0, 1 or 2, got %r" % (experiment_setting,))

    squared_diff = (output_tensor - input_img) ** 2

    # Small jitter on the (fixed) targets. The noise is drawn from the CPU RNG in
    # every setting so the global random stream is the same for all settings.
    # It is moved to wherever the embeddings live instead of a global device.
    noise1 = (torch.randn(img2_emb.shape) * 0.01).to(img2_emb.device)
    noise2 = (torch.randn(img3_emb.shape) * 0.01).to(img3_emb.device)

    sim_diff_s = sim2 * torch.mean((img1_emb - img2_emb.detach() + noise1) ** 2, dim=1)
    sim_diff_l = sim1 * torch.mean((img1_emb - img3_emb.detach() + noise2) ** 2, dim=1)

    if setting == 0:  # standard setting
        return 1.2 * torch.mean(sim_diff_s) + torch.mean(sim_diff_l) + torch.mean(squared_diff)
    if setting == 1:  # ablation of triplet loss
        return torch.mean(squared_diff)
    # setting == 2: ablation of reconstruction loss
    return 1.2 * torch.mean(sim_diff_s) + torch.mean(sim_diff_l)

# Plain MSE loss object (the training loop computes its loss with criterion4).
criterion = nn.MSELoss(reduction='mean')
optimizer = torch.optim.Adam(params=model.parameters(), lr=learning_rate)


def measure(emb1, emb2, emb3):
    """Score each triplet: 1 if the anchor is strictly closer to the positive.

    Closeness is the mean squared difference between embedding rows; ties
    count as 0. The tensors are copied to host NumPy arrays first, and a plain
    list of ints (one per row of ``emb1``) is returned.
    """
    anchors = emb1.cpu().numpy()
    positives = emb2.cpu().numpy()
    negatives = emb3.cpu().numpy()

    return [
        1 if np.mean((anchor - positives[i]) ** 2) < np.mean((anchor - negatives[i]) ** 2) else 0
        for i, anchor in enumerate(anchors)
    ]


def get_accuracy(model, loader, letter):
    """Return the fraction of triplets whose anchor embedding is nearer its positive.

    Args:
        model: the autoencoder; only its embedding output is used.
        loader: yields (anchor, positive, negative) image batches.
        letter: not used by the body; kept so existing call sites still work.

    Returns:
        ``np.mean`` of the per-triplet hits from ``measure`` (NaN, with a NumPy
        warning, when the loader yields no batches).
    """
    def _embed(images):
        # Insert the channel dimension and keep only the embedding output.
        return model(images.to(device).unsqueeze(1))[1]

    hits = []
    with torch.no_grad():
        for anchors, positives, negatives in loader:
            hits.extend(measure(_embed(anchors), _embed(positives), _embed(negatives)))

    return np.mean(hits)


# Model selection: keep the weights from the epoch with the best validation
# accuracy and report the test accuracy measured at that same epoch.
max_acc = 0
max_test_acc = 0

for_loop_op = True  # not read anywhere; kept so the module-level name still exists
save_model = copy.deepcopy(model)  # best-so-far snapshot (starts as the untrained model)
for epoch in range(num_epochs):

    for batch_idx, (data1, data2, data3, sim1, sim2) in enumerate(train_loader2):
        # unsqueeze(1) inserts the single input channel conv1 expects
        # (assumes stored images are H x W arrays — confirm against the pickle).
        img1 = data1.to(device).unsqueeze(1)
        img2 = data2.to(device).unsqueeze(1)
        img3 = data3.to(device).unsqueeze(1)

        # criterion4 detaches the positive/negative embeddings, so no gradient
        # flows through these passes; skipping autograd here saves two unused
        # graphs without changing the computed gradients.
        with torch.no_grad():
            _, img2_emb = model(img2)
            _, img3_emb = model(img3)

        # CustomDataset_new2 stores the weights as float64. Cast to the working
        # dtype (float32) before moving: the MPS backend cannot hold float64
        # tensors, and float32 matches the model's precision.
        sim1 = sim1.to(dtype=dtype).to(device)
        sim2 = sim2.to(dtype=dtype).to(device)

        # Forward pass for the anchor (reconstruction + embedding, with autograd).
        output, img1_emb = model(img1)

        loss = criterion4(output, img1, img1_emb, img2_emb, img3_emb, sim1, sim2, experiment_setting)
        # Backward pass and optimization
        optimizer.zero_grad()
        print("Epoch : ",epoch," batch idx :", batch_idx, "Loss :",loss.item())
        loss.backward()
        optimizer.step()

    val_acc = get_accuracy(model, val_loader, "val start")
    test_acc = get_accuracy(model, test_loader, "test start")

    # Strict '>' keeps the earliest epoch among ties. NOTE(review): a NaN
    # accuracy (empty validation loader) never compares greater, so the
    # untrained snapshot would be saved in that case.
    if val_acc > max_acc:
        max_acc = val_acc
        max_test_acc = test_acc
        save_model = copy.deepcopy(model)


print("Max validation accuracy : ", max_acc)
print("Max evaluation accuracy : ", max_test_acc)
model_path = "./" + str(subject_number) + ".pth"
torch.save(save_model.state_dict(), model_path)
